package src

import (
	"encoding/json"
	"fmt"

	ge "gitee.com/haifengat/gin-ex/v2"
	"gitee.com/haifengat/goctp/v2"
	zd "gitee.com/haifengat/zorm-dm/v2"
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
)

// subInstrument is the set of instrument IDs clients have subscribed to.
// postSubscript adds keys and postUnsubscript removes them; values carry no
// data (struct{} is used purely as a set marker).
//
// NOTE(review): this map is mutated from gin handlers, which net/http may run
// on concurrent goroutines; confirm access is serialized elsewhere or guard it
// with a mutex.
var subInstrument = make(map[string]struct{})

// Route registers the market-data and trading HTTP endpoints on g.
//
//   - SSE endpoints via initSSE (real-time updates)
//   - POST /subscript, /unsubscript: postSubscript / postUnsubscript
//   - GET  /login: reports whether trd is initialised and logged in
//   - GET  /ids/:tblName: getIDs
//   - POST /latest: getLatest
//   - POST /get/:tblName: postGet
//   - POST /order: places a limit order through trd
//   - POST /cancel: cancels the order identified by "order_id"
//
//	@param g router group the routes are attached to
func Route(g *gin.RouterGroup) {
	// initEnv() is deliberately not called here: running it together with run24 causes errors.
	initSSE(g) // real-time updates

	// Subscribe
	g.POST("/subscript", ge.CreateUpHandler(postSubscript))
	// Unsubscribe
	g.POST("/unsubscript", ge.CreateUpHandler(postUnsubscript))

	// ================ Trading ==============
	// Interface status
	g.GET("/login", ge.CreateQryHandler(func(ctx *gin.Context) (data any, total int, errRtn error) {
		data = map[string]any{"is_login": trd != nil && isLogin}
		return
	}))
	// Distinct IDs of a table
	g.GET("/ids/:tblName", ge.CreateQryHandler(getIDs))

	// Market data: latest bars
	g.POST("/latest", ge.CreateQryHandler(getLatest))

	// Generic table query
	g.POST("/get/:tblName", ge.CreateQryHandler(postGet))

	// Place a limit order
	g.POST("/order", ge.CreateUpHandler(func(ctx *gin.Context) (errRtn error) {
		if trd == nil || !isLogin {
			return errors.New("交易未初始化")
		}
		var order OrderInsert
		if errRtn = ctx.ShouldBindJSON(&order); errRtn != nil {
			return
		}
		// The first element of Director/Offset is indexed below; an empty
		// value would otherwise panic with index out of range.
		if len(order.Director) == 0 || len(order.Offset) == 0 {
			return errors.New("参数错误: Director 或 Offset 为空")
		}
		_, rsp := trd.ReqOrderInsertLimit(string(order.InstrumentID), goctp.TThostFtdcDirectionType(order.Director[0]), goctp.TThostFtdcOffsetFlagType(order.Offset[0]), order.Price, order.Volume)
		if rsp.ErrorID != 0 {
			return errors.New(rsp.ErrorMsg.String())
		}
		return nil
	}))
	// Cancel an order by "order_id"
	g.POST("/cancel", ge.CreateUpHandler(func(ctx *gin.Context) (errRtn error) {
		// Same guard as /order: trd is dereferenced below.
		if trd == nil || !isLogin {
			return errors.New("交易未初始化")
		}
		mp, err := ge.PostParams(ctx)
		if err != nil {
			return errors.Wrap(err, "参数错误")
		}
		raw, ok := mp["order_id"]
		if !ok {
			// No order_id supplied: nothing to cancel (unchanged behaviour).
			return nil
		}
		id, ok := raw.(string)
		if !ok {
			return errors.Errorf("参数错误: order_id 应为字符串, 实际为 %T", raw)
		}
		// ReqOrderAction returning -9 is treated as "no matching order".
		if trd.ReqOrderAction(id) == -9 {
			return errors.New("未有对应的委托: " + id)
		}
		return nil
	}))
}

// postSubscript adds instruments to the subscription set (subInstrument).
// For each newly accepted instrument that has a cached tick in md.Ticks, it
// pushes that tick immediately via sseTick.
//
// The instrument is read from the "InstrumentID" parameter, falling back to
// "instrument_id" (the key postUnsubscript uses). The value may be a single
// string or an array of strings; non-string elements are ignored. IDs not
// present in trd.Instruments are ignored as well.
//
//	@param ctx gin request context
//	@return errRtn non-nil when md/trd are not initialised or the body cannot be parsed
func postSubscript(ctx *gin.Context) (errRtn error) {
	if md == nil {
		return errors.New("行情未初始化")
	}
	// trd.Instruments is used to validate IDs below; guard against a nil deref.
	if trd == nil {
		return errors.New("交易未初始化")
	}
	mp, err := ge.PostParams(ctx)
	if err != nil {
		return errors.Wrap(err, "参数错误")
	}
	raw, ok := mp["InstrumentID"]
	if !ok {
		raw, ok = mp["instrument_id"]
	}
	if !ok {
		return nil
	}

	var ids []string
	switch v := raw.(type) {
	case string:
		ids = []string{v}
	case []string:
		ids = v
	case []any: // JSON arrays decode to []any
		for _, item := range v {
			if s, ok := item.(string); ok {
				ids = append(ids, s)
			}
		}
	}

	for _, inst := range ids {
		if _, ok := trd.Instruments[inst]; !ok {
			continue // unknown instrument
		}
		subInstrument[inst] = struct{}{}
		if tick, ok := md.Ticks[inst]; ok {
			sseTick(&tick)
		}
	}
	return nil
}

// postUnsubscript removes instruments from the subscription set (subInstrument).
//
// The instruments are read from the "instrument_id" parameter, falling back to
// "InstrumentID" (the key postSubscript uses). The value may be an array of
// strings or a single string. Removing an ID that is not subscribed is a no-op.
// The whole input is validated before anything is removed, so a malformed
// request leaves the set unchanged.
//
//	@param ctx gin request context
//	@return errRtn non-nil when the body cannot be parsed or has the wrong shape
func postUnsubscript(ctx *gin.Context) (errRtn error) {
	mp, err := ge.PostParams(ctx)
	if err != nil {
		return errors.Wrap(err, "参数错误")
	}
	raw, ok := mp["instrument_id"]
	if !ok {
		raw, ok = mp["InstrumentID"]
	}
	if !ok {
		return nil
	}

	var ids []string
	switch v := raw.(type) {
	case string:
		ids = []string{v}
	case []string:
		ids = v
	case []any: // JSON arrays decode to []any
		ids = make([]string, 0, len(v))
		for _, item := range v {
			s, ok := item.(string)
			if !ok {
				return errors.Errorf("参数错误: instrument_id 元素应为字符串, 实际为 %T", item)
			}
			ids = append(ids, s)
		}
	default:
		return errors.Errorf("参数错误: instrument_id 应为字符串或字符串数组, 实际为 %T", raw)
	}

	for _, id := range ids {
		delete(subInstrument, id)
	}
	return nil
}

// getLatest returns the bars stored in Redis for one instrument.
//
// The instrument ID comes from the "instrument_id" condition parsed by
// ge.PostGetParams. It is used directly as the Redis list key, and the whole
// list (LRANGE 0 -1) is read. Each element is JSON-decoded into a Bar and
// passed through Bar.Fix. When "instrument_id" is absent, nothing is returned
// and no error is reported.
//
//	@param ctx gin request context
//	@return data   []Bar, set only when every element decoded successfully
//	@return total  number of bars returned
//	@return errRtn parameter, Redis, or JSON decoding error
func getLatest(ctx *gin.Context) (data any, total int, errRtn error) {
	where, _, _, _, err := ge.PostGetParams(ctx)
	if err != nil {
		errRtn = errors.Wrap(err, "查询参数")
		return
	}
	raw, ok := where["instrument_id"]
	if !ok {
		return
	}
	inst, ok := raw.(string)
	if !ok {
		errRtn = errors.Errorf("查询参数: instrument_id 应为字符串, 实际为 %T", raw)
		return
	}

	jsons, err := rdb.LRange(ctxRedis, inst, 0, -1).Result()
	if err != nil {
		errRtn = errors.Wrap(err, "读取 Redis "+inst)
		return
	}
	bars := make([]Bar, 0, len(jsons))
	for _, v := range jsons {
		var bar Bar
		if err := json.Unmarshal([]byte(v), &bar); err != nil {
			// Do not return a partially decoded list alongside the error.
			errRtn = errors.Wrap(err, "解析 Bar")
			return
		}
		bars = append(bars, *bar.Fix())
	}
	data = bars
	total = len(bars)
	return
}

// postGet handles POST /get/:tblName: a generic query against one of a fixed
// set of tables, with optional paging.
//
// Conditions, current page, page size and an extra clause are parsed by
// ge.PostGetParams. Paging is applied only when the current page is > 0. For
// OrderField, TradeField and PositionField, a missing "TradingDay" condition
// defaults to the package-level tradingDay, provided that value is non-empty.
//
//	@param ctx gin context; the table name is the :tblName path parameter
//	@return data   rows returned by zd.SelectMap
//	@return total  p.TotalCount when paging, otherwise 0
//	@return errRtn parameter error, unknown table, or query error
func postGet(ctx *gin.Context) (data any, total int, errRtn error) {
	tblName := ctx.Param("tblName")
	// `extra` was previously named `append`, which shadowed the builtin.
	where, cur, size, extra, err := ge.PostGetParams(ctx)
	if err != nil {
		errRtn = errors.Wrap(err, "查询参数")
		return
	}
	var p *zd.PageInfo
	if cur > 0 {
		p = &zd.PageInfo{PageNo: cur, PageSize: size}
	}

	// defaultTradingDay scopes trading-day tables to the current trading day
	// unless the caller already filters on TradingDay.
	defaultTradingDay := func() {
		if _, ok := where["TradingDay"]; ok || len(tradingDay) == 0 {
			return
		}
		if where == nil { // guard: writing to a nil map would panic
			where = map[string]any{}
		}
		where["TradingDay"] = tradingDay
	}

	// NOTE(review): `extra` comes straight from the request body and is handed
	// to zd.SelectMap; if zd splices it into the SQL text verbatim this is an
	// injection vector — confirm zd validates/escapes it.
	switch tblName {
	case "InstrumentField":
		data, errRtn = zd.SelectMap[InstrumentField](ctxDAO, p, where, extra)
	case "Bar":
		data, errRtn = zd.SelectMap[Bar](ctxDAO, p, where, extra)
	case "AccountField":
		data, errRtn = zd.SelectMap[AccountField](ctxDAO, p, where, extra)
	case "OrderField":
		defaultTradingDay()
		data, errRtn = zd.SelectMap[OrderField](ctxDAO, p, where, extra)
	case "TradeField":
		defaultTradingDay()
		data, errRtn = zd.SelectMap[TradeField](ctxDAO, p, where, extra)
	case "PositionField":
		defaultTradingDay()
		data, errRtn = zd.SelectMap[PositionField](ctxDAO, p, where, extra)
	default:
		errRtn = errors.New("未知表名: " + tblName)
	}
	if p != nil {
		total = p.TotalCount
	}
	if errRtn != nil {
		errRtn = errors.Wrap(errRtn, "查询 "+tblName)
	}
	return
}

// getIDs handles GET /ids/:tblName and returns the distinct main IDs of a
// table as a list of strings.
//
// Supported tables: AccountField (column AccountID) and InstrumentField
// (column InstrumentID). Conditions from the query string (ge.QryParams) are
// normalised with zd.FixQryParams and used as the filter. Any other table name
// is rejected, matching postGet.
//
//	@param ctx gin context; the table name is the :tblName path parameter
//	@return data   []string of IDs (NULL values are skipped)
//	@return total  number of IDs
//	@return errRtn unknown table or query error
func getIDs(ctx *gin.Context) (data any, total int, errRtn error) {
	tblName := ctx.Param("tblName")
	// NOTE(review): errors from QryParams/FixQryParams are ignored, as before
	// (best-effort filtering); confirm this is intended.
	where, _, _, _, _ := ge.QryParams(ctx)

	// idColumns selects the distinct values of col, aliased to ID.
	idColumns := func(col string) []string {
		return []string{fmt.Sprintf("Distinct %s AS ID", col)}
	}

	var (
		mps []map[string]any
		err error
	)
	switch tblName {
	case "AccountField":
		filter, _ := zd.FixQryParams[AccountField](ctxDAO, where) // normalise query conditions
		mps, err = zd.SelectMapColumns[AccountField](ctxDAO, nil, filter, idColumns("AccountID"))
	case "InstrumentField":
		filter, _ := zd.FixQryParams[InstrumentField](ctxDAO, where) // normalise query conditions
		mps, err = zd.SelectMapColumns[InstrumentField](ctxDAO, nil, filter, idColumns("InstrumentID"))
	default:
		errRtn = errors.New("未知表名: " + tblName)
		return
	}
	if err != nil {
		errRtn = errors.Wrap(err, "查询 "+tblName)
		return
	}

	ids := make([]string, 0, len(mps))
	for _, row := range mps {
		// Iterate values rather than reading row["ID"], so the result does not
		// depend on how the driver cases the alias.
		for _, v := range row {
			switch id := v.(type) {
			case nil:
				// NULL carries no ID; skip instead of panicking on assertion.
			case string:
				ids = append(ids, id)
			case []byte:
				ids = append(ids, string(id))
			default:
				ids = append(ids, fmt.Sprint(id))
			}
		}
	}
	data = ids
	total = len(ids)
	return
}
